{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "greek-engineering",
   "metadata": {},
   "source": [
    "# Using Pre Generated Primitive for Operations"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "voluntary-exercise",
   "metadata": {},
   "source": [
    "For doing any private multi-party computation in SyMPC we generally require primitives. In SyMPC we can generate these primitives in two ways:\n",
    "\n",
    "- Generate Primitive on the fly (during the computation).\n",
    "- Generate Primitive before the computation.\n",
    "\n",
    "Author: \n",
    "- Anubhav Singh - [Twitter](https://twitter.com/aanurraj) - [GitHub](https://github.com/aanurraj)\n"
   ]
  },
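  {
   "cell_type": "markdown",
   "id": "beaver-triple-sketch",
   "metadata": {},
   "source": [
    "To make the idea of a primitive concrete, the sketch below multiplies two secret-shared integers with a Beaver triple, the kind of primitive consumed by multiplication on additive shares. It is plain Python over integers modulo `2**64` and does not use SyMPC; `share`, `reconstruct`, `make_triple`, and `beaver_mul` are hypothetical helpers for illustration, and SyMPC's own implementation may differ.\n",
    "\n",
    "```python\n",
    "import secrets\n",
    "\n",
    "MOD = 2**64  # assumed ring size for this sketch\n",
    "\n",
    "\n",
    "def share(value, n_parties=2):\n",
    "    # Split value into additive shares that sum to value modulo MOD.\n",
    "    shares = [secrets.randbelow(MOD) for _ in range(n_parties - 1)]\n",
    "    shares.append((value - sum(shares)) % MOD)\n",
    "    return shares\n",
    "\n",
    "\n",
    "def reconstruct(shares):\n",
    "    return sum(shares) % MOD\n",
    "\n",
    "\n",
    "def make_triple(n_parties=2):\n",
    "    # The primitive: shares of random a and b, and of c = a * b.\n",
    "    a, b = secrets.randbelow(MOD), secrets.randbelow(MOD)\n",
    "    return share(a, n_parties), share(b, n_parties), share(a * b % MOD, n_parties)\n",
    "\n",
    "\n",
    "def beaver_mul(x_sh, y_sh, triple):\n",
    "    a_sh, b_sh, c_sh = triple\n",
    "    # The parties open the masked values eps = x - a and delta = y - b.\n",
    "    eps = reconstruct([(x - a) % MOD for x, a in zip(x_sh, a_sh)])\n",
    "    delta = reconstruct([(y - b) % MOD for y, b in zip(y_sh, b_sh)])\n",
    "    # Each party computes its share locally; eps * delta is added by one party only.\n",
    "    z_sh = [(c + eps * b + delta * a) % MOD for a, b, c in zip(a_sh, b_sh, c_sh)]\n",
    "    z_sh[0] = (z_sh[0] + eps * delta) % MOD\n",
    "    return z_sh\n",
    "\n",
    "\n",
    "x_sh, y_sh = share(6), share(7)\n",
    "product = reconstruct(beaver_mul(x_sh, y_sh, make_triple()))\n",
    "print(product)  # 42\n",
    "```\n",
    "\n",
    "The triple depends only on random values, not on the inputs being multiplied, which is why it can be generated either during the computation or ahead of time."
   ]
  },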
  {
   "cell_type": "markdown",
   "id": "selective-palestinian",
   "metadata": {},
   "source": [
    "In this tutorial we will see how to generate this primitives before performing any computation and what advantage it has over the primitive generated on the fly.\n",
    "\n",
    "We will start with basic import for doing a multi-party computation. You can read more about these imports in introduction notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "thermal-saying",
   "metadata": {},
   "outputs": [],
   "source": [
    "import syft as sy\n",
    "\n",
    "sy.logger.remove()\n",
    "import torch\n",
    "\n",
    "from sympc.session import Session\n",
    "from sympc.session import SessionManager\n",
    "from sympc.tensor import MPCTensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "dense-death",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the virtual machines that would be use in the computation\n",
    "alice_vm = sy.VirtualMachine(name=\"alice\")\n",
    "bob_vm = sy.VirtualMachine(name=\"bob\")\n",
    "\n",
    "# Get clients from each VM\n",
    "alice = alice_vm.get_root_client()\n",
    "bob = bob_vm.get_root_client()\n",
    "\n",
    "parties = [alice, bob]\n",
    "\n",
    "# Setup the session for the computation\n",
    "session = Session(parties=parties)\n",
    "SessionManager.setup_mpc(session)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "surprising-fundamentals",
   "metadata": {},
   "source": [
    "Next, we will declare PyTorch tensors and will share these tensors among parties and perform all the computations on these tensors. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "optical-geology",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[MPCTensor]\n",
      "Shape: torch.Size([1000, 1000])\n",
      "Requires Grad: False\n",
      "\t| <VirtualMachineClient: alice Client> -> ShareTensorPointer\n",
      "\t| <VirtualMachineClient: bob Client> -> ShareTensorPointer\n",
      "[MPCTensor]\n",
      "Shape: torch.Size([1000, 1000])\n",
      "Requires Grad: False\n",
      "\t| <VirtualMachineClient: alice Client> -> ShareTensorPointer\n",
      "\t| <VirtualMachineClient: bob Client> -> ShareTensorPointer\n"
     ]
    }
   ],
   "source": [
    "# Define the private values to shares\n",
    "x_secret = torch.randn((1000, 1000))\n",
    "y_secret = torch.randn((1000, 1000))\n",
    "\n",
    "# Share the secret between the parties\n",
    "x = x_secret.share(parties=parties)\n",
    "y = y_secret.share(parties=parties)\n",
    "\n",
    "print(x)\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "continuous-certificate",
   "metadata": {},
   "source": [
    "Lets checkout how much time a matmul operations takes to compute when primitives are generated on the fly. ⏱"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "golden-memorial",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "common-overview",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "38.517215967178345\n"
     ]
    }
   ],
   "source": [
    "t1 = time.time()\n",
    "x @ y\n",
    "t2 = time.time()\n",
    "print(t2-t1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "embedded-formula",
   "metadata": {},
   "source": [
    "The above result may vary depending on the configuration of the machine on which the operation is performed.\n",
    "Next, we will see how to generate the primitive before hand."
   ]
  },
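  {
   "cell_type": "markdown",
   "id": "timing-helper-sketch",
   "metadata": {},
   "source": [
    "The timing pattern used for the measurement above can be wrapped in a small helper. This is a sketch rather than part of SyMPC: `timed` is a hypothetical name, and it uses `time.perf_counter` from the standard library, a monotonic clock intended for measuring elapsed time.\n",
    "\n",
    "```python\n",
    "import time\n",
    "\n",
    "\n",
    "def timed(fn, *args, **kwargs):\n",
    "    # Run fn once and return (result, elapsed seconds).\n",
    "    start = time.perf_counter()\n",
    "    result = fn(*args, **kwargs)\n",
    "    return result, time.perf_counter() - start\n",
    "\n",
    "\n",
    "# Stand-in workload so the snippet runs without SyMPC.\n",
    "total, elapsed = timed(sum, range(1_000_000))\n",
    "print(f'{elapsed:.4f} s')\n",
    "```\n",
    "\n",
    "In this notebook, `timed(lambda: x @ y)` would return the product together with the elapsed time."
   ]
  },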
  {
   "cell_type": "markdown",
   "id": "recovered-cheat",
   "metadata": {},
   "source": [
    "### Pre Generating Primitives"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "jewish-singles",
   "metadata": {},
   "source": [
    "For pre generating the primitives we need to import the CryptoPrimitiveProvider, lets do that."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "centered-ozone",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sympc.store.crypto_primitive_provider import CryptoPrimitiveProvider"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fluid-thirty",
   "metadata": {},
   "source": [
    "We have to follow two steps to generate the primitives.\n",
    "- Log the primitive we need.\n",
    "- Generate the primitives using the above generated logs.\n",
    "\n",
    "The logs are generally the meta information needed in the computation which mainly comprises of shapes of the primitive required."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "internal-charge",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'beaver_matmul': [({'a_shape': (1000, 1000), 'b_shape': (1000, 1000)}, {'a_shape': (1000, 1000), 'b_shape': (1000, 1000), 'nr_parties': 2})]}\n"
     ]
    }
   ],
   "source": [
    "CryptoPrimitiveProvider.start_logging()\n",
    "x @ y\n",
    "log = CryptoPrimitiveProvider.stop_logging()\n",
    "\n",
    "print(log)"
   ]
  },
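  {
   "cell_type": "markdown",
   "id": "primitive-log-summary",
   "metadata": {},
   "source": [
    "The log is a plain dictionary, so it can be inspected before any primitives are generated. The sketch below assumes only the structure visible in the printed output above (an operation name mapped to a list of pairs of keyword dictionaries); `summarize_log` is a hypothetical helper, not a SyMPC function.\n",
    "\n",
    "```python\n",
    "def summarize_log(log):\n",
    "    # Return one (operation, a_shape, b_shape, nr_parties) row per logged request.\n",
    "    rows = []\n",
    "    for op_name, entries in log.items():\n",
    "        for first, second in entries:\n",
    "            merged = {**first, **second}\n",
    "            rows.append(\n",
    "                (op_name, merged.get('a_shape'), merged.get('b_shape'), merged.get('nr_parties'))\n",
    "            )\n",
    "    return rows\n",
    "\n",
    "\n",
    "# Literal copy of the log printed above, so the snippet runs without SyMPC.\n",
    "example_log = {\n",
    "    'beaver_matmul': [\n",
    "        (\n",
    "            {'a_shape': (1000, 1000), 'b_shape': (1000, 1000)},\n",
    "            {'a_shape': (1000, 1000), 'b_shape': (1000, 1000), 'nr_parties': 2},\n",
    "        )\n",
    "    ]\n",
    "}\n",
    "for row in summarize_log(example_log):\n",
    "    print(row)\n",
    "```\n",
    "\n",
    "Here a single `x @ y` produced one `beaver_matmul` request for two 1000 x 1000 operands shared between two parties."
   ]
  },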
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "peaceful-margin",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate the primitives using the logs generated in the previous step.\n",
    "CryptoPrimitiveProvider.generate_primitive_from_dict(primitive_log=log, session=session)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cognitive-queensland",
   "metadata": {},
   "source": [
    "We can now calculate the time required to do the same computation using the pre-generated primitives."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "personalized-barrel",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "26.111281156539917\n"
     ]
    }
   ],
   "source": [
    "t1 = time.time()\n",
    "x @ y\n",
    "t2 = time.time()\n",
    "print(t2-t1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "pharmaceutical-jonathan",
   "metadata": {},
   "source": [
    "We can clearly see the difference in the calculation time. This pre generated way comes handy when we want to do long and complex calculations."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
